{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Multiplication\n",
    "\n",
    "A quick and easy example to start off with is to build a toy model which takes in two numbers, and outputs the result. Although the model doesn't accomplish anything significant the same techniques can be used to model and train much larger and complex networks.\n",
    "\n",
    "`Numpy` is seeded to allow deterministic results, this seeding has no relevance to the architecture or the training of the network"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "%matplotlib inline\n",
    "\n",
    "import nengo\n",
    "import tensorflow as tf\n",
    "import nengo_dl\n",
    "import numpy as np\n",
    "import matplotlib.pyplot as plt\n",
    "from math import sqrt\n",
    "\n",
    "np.random.seed(0)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Architecture\n",
    "\n",
    "We connect two input nodes (`i_1`, `i_2`), both of which generate random numbers, to ensemble `a`. Then `a` is connected to a second ensemble `b`, which we probe for the output."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "with nengo.Network() as net:\n",
    "    \n",
    "    \n",
    "    net.config[nengo.Ensemble].neuron_type = nengo.RectifiedLinear()\n",
    "    net.config[nengo.Ensemble].gain = nengo.dists.Choice([1])\n",
    "    net.config[nengo.Ensemble].bias = nengo.dists.Uniform(-1, 1)\n",
    "    net.config[nengo.Connection].synapse = None\n",
    "    \n",
    "    i_1 = nengo.Node(output=lambda t: np.random.random())\n",
    "    i_2 = nengo.Node(output=lambda t: np.random.random())\n",
    "\n",
    "    a = nengo.Ensemble(100, 2)\n",
    "    b = nengo.Ensemble(100, 1)\n",
    "    nengo.Connection(i_1, a[0])\n",
    "    nengo.Connection(i_2, a[1])\n",
    "    nengo.Connection(a, b, function=lambda x: [0])\n",
    "    \n",
    "    i_1_probe = nengo.Probe(i_1)\n",
    "    i_2_probe = nengo.Probe(i_2)\n",
    "    output_probe = nengo.Probe(b)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Before we train the network the output is approximately zero, since that is the function we specified on the connection from `a` to `b`. However we don't want that output, so we need to train the network to multiply the inputs."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "n_steps = 50\n",
    "minibatch_size = 256\n",
    "# Showing the output of the model pre training\n",
    "with nengo_dl.Simulator(net) as sim:\n",
    "    sim.run_steps(n_steps)\n",
    "    true_value = np.multiply(sim.data[i_1_probe], sim.data[i_2_probe])\n",
    "    fig = plt.figure()\n",
    "    fig.suptitle('Pre-Training')\n",
    "    plt.plot(sim.data[output_probe], 'g', label='predicted value')\n",
    "    plt.plot(true_value, 'm', label='true value')\n",
    "    plt.legend()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "To train the network we generate training feeds which consist of two batches of random numbers (the inputs) and then the result of those batches multiplied together (the outputs). Additionally we generate some test data to easily track the progress of the network throughout training. "
   ]
  },
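  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "As a rough illustration of the data layout described above, the NumPy-only sketch below builds one such feed. The helper name `make_product_batch` and its local `RandomState` are assumptions introduced for this example and are not part of Nengo DL; the `(n_examples, n_steps, dimensions)` shape mirrors the arrays built in this notebook. A local generator is used so that the global NumPy seed set earlier is left untouched."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "\n",
    "\n",
    "def make_product_batch(n_examples, n_steps=1, rng=None):\n",
    "    # hypothetical helper (not part of Nengo DL): returns two input arrays\n",
    "    # and their elementwise product, each shaped (n_examples, n_steps, 1)\n",
    "    rng = np.random.RandomState(0) if rng is None else rng\n",
    "    x_1 = rng.uniform(0, 1, size=(n_examples, n_steps, 1))\n",
    "    x_2 = rng.uniform(0, 1, size=(n_examples, n_steps, 1))\n",
    "    return x_1, x_2, x_1 * x_2\n",
    "\n",
    "\n",
    "x_1, x_2, y = make_product_batch(4)\n",
    "print(y.shape)  # one target per example and timestep"
   ]
  },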
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "with nengo_dl.Simulator(net, minibatch_size=minibatch_size) as sim:\n",
    "    # This feed is used as the \"test\" data\n",
    "    # It's run through the network after every iteration \n",
    "    # to allow easy visulization of how the output is changing\n",
    "    test_inputs = {i_1: np.random.uniform(0, 1, size=(minibatch_size, 1, 1)),\n",
    "                   i_2: np.random.uniform(0, 1, size=(minibatch_size, 1, 1))}\n",
    "    test_targets = {output_probe: np.multiply(test_inputs[i_1], test_inputs[i_2])}\n",
    "    \n",
    "    # running through 10 rounds of training/testing\n",
    "    outputs = []\n",
    "    optimizer = tf.train.MomentumOptimizer(5e-2, 0.9)\n",
    "    for i in range(10):\n",
    "        # check performance on test set\n",
    "        sim.step(input_feeds=test_inputs)\n",
    "        print(\"LOSS: \" + str(sim.loss(test_inputs, test_targets, 'mse')))\n",
    "        outputs.append(sim.data[output_probe].flatten())\n",
    "        \n",
    "        # run training\n",
    "        input_feed = {i_1: np.random.uniform(0, 1, size=(minibatch_size*5, 1, 1)),\n",
    "                      i_2: np.random.uniform(0, 1, size=(minibatch_size*5, 1, 1))}\n",
    "        output_feed = {output_probe: np.multiply(input_feed[i_1], input_feed[i_2])}\n",
    "        sim.train(input_feed, output_feed, optimizer, n_epochs=12)\n",
    "        sim.soft_reset(include_probes=True)\n",
    "        \n",
    "    # check final performance on test set\n",
    "    sim.step(input_feeds=test_inputs)\n",
    "    print(\"LOSS: \" + str(sim.loss(test_inputs, test_targets, 'mse')))\n",
    "    outputs.append(sim.data[output_probe].flatten())\n",
    "        "
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "We visualize the results by plotting the pre-trained, trained and ideal outputs next to each other"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "fig = plt.figure()\n",
    "fig.suptitle('Pre/Post Training Comparison')\n",
    "plt.plot(outputs[0][:50], 'r', label='pre-training')\n",
    "plt.plot(outputs[10][:50], 'k', label='trained')\n",
    "plt.plot(test_targets[output_probe].flatten()[:50], 'm', label='ideal')\n",
    "plt.legend()"
   ]
  }
 ],
 "metadata": {
  "anaconda-cloud": {},
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.5.4"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
